from typing import List, Iterator, Dict, Tuple, Any

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import Module

from allennlp.data import Vocabulary
from allennlp.models import Model
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.nn.util import get_text_field_mask, move_to_device
from allennlp.training.metrics import CategoricalAccuracy

from config import Config
from .encoder.encoder_factory import EncoderFactory
from .embedding.embedder_factory import EmbedderFactory


class TextClassifier(Model):
    """Text classifier: embedder -> dropout -> encoder -> dropout -> linear.

    The token embedder and encoder are selected by name and configured from
    ``cf``. The output layer has one unit per entry in the vocabulary's
    ``'labels'`` namespace.
    """

    def __init__(self,
                 cf: Config,
                 vocab: Vocabulary,
                 encoder_type=None,
                 token_type=None,
                 dropout: float = 0.3):
        """Build the classifier.

        Args:
            cf: project configuration. Supplies the default encoder/token
                types (``cf.encoder`` / ``cf.token``), the embedder factory
                (``cf.token_config(token_type)['embedder']``) and the encoder
                config (``cf.encoder_config[encoder_type]``).
            vocab: vocabulary; the size of its ``'labels'`` namespace sets the
                number of output classes.
            encoder_type: encoder name for ``EncoderFactory``; falls back to
                ``cf.encoder`` when ``None``.
            token_type: token/embedder name for ``cf.token_config``; falls
                back to ``cf.token`` when ``None``.
            dropout: dropout probability applied to both the embeddings and
                the encoder output.
        """
        super().__init__(vocab)

        # Fall back to the config independently for each argument, so an
        # explicitly passed value is not discarded just because the other
        # one was omitted.
        if encoder_type is None:
            encoder_type = cf.encoder
        if token_type is None:
            token_type = cf.token

        self.embeddings = BasicTextFieldEmbedder(cf.token_config(token_type)['embedder'](vocab))

        self.encoder = EncoderFactory(encoder_type).get_encoder(input_dim=self.embeddings.get_output_dim(),
                                                                config=cf.encoder_config[encoder_type])
        self.hidden2tag = nn.Linear(in_features=self.encoder.get_output_dim(),
                                    out_features=vocab.get_vocab_size('labels'))
        # A single dropout module is shared by the embedding and encoder stages.
        self.dropout = nn.Dropout(p=dropout)
        # reduction='none' keeps the per-instance losses (exposed as
        # 'single_loss'); the batch mean is computed in warp_outputs.
        self.loss = nn.NLLLoss(reduction='none')
        self.accuracy = CategoricalAccuracy()

    def forward(self,
                sentence: Dict[str, torch.Tensor],
                label: torch.Tensor = None):
        """Classify a batch of sentences.

        Args:
            sentence: indexed text-field tensors for the batch, as consumed by
                ``get_text_field_mask`` and the text field embedder.
            label: gold class indices, or ``None`` at inference time.
                NOTE(review): presumably shape (batch,) — confirm against the
                dataset reader.

        Returns:
            The dict produced by ``warp_outputs``. It always has ``'logits'``
            and ``'pred'``. When ``label`` is given it also has
            ``'single_loss'``, ``'gold'``, ``'gold_prob'`` and ``'loss'``.
        """
        mask = get_text_field_mask(sentence)
        embedded = self.dropout(self.embeddings(sentence))

        # NOTE(review): assumes the encoder pools to one vector per instance,
        # so logits are (batch, num_labels) — TODO confirm for each encoder.
        encoded = self.dropout(self.encoder(embedded, mask))
        logits = self.hidden2tag(encoded)

        # Accuracy needs gold labels. Skip it at inference time instead of
        # passing None into CategoricalAccuracy.
        if label is not None:
            self.accuracy(logits, label)

        return self.warp_outputs(logits, label)

    def warp_outputs(self, logits, label):
        """Package logits (and, when labels are given, loss terms) into a dict.

        The name is kept as-is (likely meant "wrap") for caller compatibility.

        Args:
            logits: unnormalised class scores; the class dimension is last.
            label: gold class indices, or ``None``.

        Returns:
            A dict with ``'logits'`` and ``'pred'`` (argmax over the last
            dimension). When ``label`` is not ``None`` it also contains:
            ``'single_loss'`` (per-instance NLL), ``'gold'`` (the labels),
            ``'gold_prob'`` (probability of the gold class) and ``'loss'``
            (mean of the per-instance losses).
        """
        output = {'logits': logits, 'pred': logits.argmax(-1)}
        if label is not None:
            # NLLLoss(reduction='none') on log-softmax yields -log p(gold) per
            # instance, so exp(-loss) recovers the gold-class probability.
            per_instance_loss = self.loss(F.log_softmax(logits, -1), label)
            output['single_loss'] = per_instance_loss
            output['gold'] = label
            output['gold_prob'] = torch.exp(-per_instance_loss)
            output['loss'] = per_instance_loss.mean()
        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the running accuracy; if ``reset`` is true, also reset it."""
        return {"accuracy": self.accuracy.get_metric(reset)}
